{ "cells": [ { "cell_type": "markdown", "id": "cMny8Ri7RvqC", "metadata": { "id": "cMny8Ri7RvqC" }, "source": [ "\n", "### **2. T-learner**\n", "The second learner is called T-learner, which stands for \"two learners\" (Künzel et al., 2019). Instead of fitting a single model to estimate the potential outcomes under both the treatment and control groups, T-learner fits separate models for $\\mathbb{E}[R(1)|S]$ and $\\mathbb{E}[R(0)|S]$ and then combines them into the final HTE estimator.\n", "\n", "Define the control response function as $\\mu_0(s)=\\mathbb{E}[R(0)|S=s]$, and the treatment response function as $\\mu_1(s)=\\mathbb{E}[R(1)|S=s]$. The T-learner algorithm is summarized below:\n", "\n", "**Step 1:** Estimate $\\mu_0(s)$ and $\\mu_1(s)$ separately with any regression algorithm or supervised machine learning method, fitting $\\hat\\mu_0$ on the control-group samples only and $\\hat\\mu_1$ on the treatment-group samples only;\n", "\n", "**Step 2:** Estimate the HTE by\n", "\\begin{equation*}\n", "\\hat{\\tau}_{\\text{T-learner}}(s)=\\hat\\mu_1(s)-\\hat\\mu_0(s).\n", "\\end{equation*}\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "eRpP5k9MBtzO", "metadata": { "id": "eRpP5k9MBtzO" }, "outputs": [], "source": [ "# import related packages\n", "import numpy as np\n", "import pandas as pd\n", "from matplotlib import pyplot as plt\n", "from sklearn.ensemble import GradientBoostingRegressor\n", "from sklearn.linear_model import LinearRegression\n", "from causaldm.learners.CEL.Single_Stage import _env_getdata_CEL" ] }, { "cell_type": "markdown", "id": "XUu695Qrf61-", "metadata": { "id": "XUu695Qrf61-" }, "source": [ "### MovieLens Data" ] }, { "cell_type": "code", "execution_count": 2, "id": "JhfJntzcVVy2", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 424 }, "executionInfo": { "elapsed": 288, "status": "ok", "timestamp": 1676750101543, "user": { "displayName": "Yang Xu", "userId": "12270366590264264299" }, "user_tz": 300 }, "id": "JhfJntzcVVy2", "outputId": "7fab8a7a-7cd9-445c-a005-9a6d1994a071" }, "outputs": [ { "data": { 
"text/plain": [ " user_id movie_id rating age Drama Sci-Fi gender_M \\\n", "0 48.0 1193.0 4.0 25.0 1.0 0.0 1.0 \n", "1 48.0 919.0 4.0 25.0 1.0 0.0 1.0 \n", "2 48.0 527.0 5.0 25.0 1.0 0.0 1.0 \n", "3 48.0 1721.0 4.0 25.0 1.0 0.0 1.0 \n", "4 48.0 150.0 4.0 25.0 1.0 0.0 1.0 \n", "... ... ... ... ... ... ... ... \n", "65637 5878.0 3300.0 2.0 25.0 0.0 1.0 0.0 \n", "65638 5878.0 1391.0 1.0 25.0 0.0 1.0 0.0 \n", "65639 5878.0 185.0 4.0 25.0 0.0 1.0 0.0 \n", "65640 5878.0 2232.0 1.0 25.0 0.0 1.0 0.0 \n", "65641 5878.0 426.0 3.0 25.0 0.0 1.0 0.0 \n", "\n", " occupation_academic/educator occupation_college/grad student \\\n", "0 0.0 1.0 \n", "1 0.0 1.0 \n", "2 0.0 1.0 \n", "3 0.0 1.0 \n", "4 0.0 1.0 \n", "... ... ... \n", "65637 0.0 0.0 \n", "65638 0.0 0.0 \n", "65639 0.0 0.0 \n", "65640 0.0 0.0 \n", "65641 0.0 0.0 \n", "\n", " occupation_executive/managerial occupation_other \\\n", "0 0.0 0.0 \n", "1 0.0 0.0 \n", "2 0.0 0.0 \n", "3 0.0 0.0 \n", "4 0.0 0.0 \n", "... ... ... \n", "65637 0.0 1.0 \n", "65638 0.0 1.0 \n", "65639 0.0 1.0 \n", "65640 0.0 1.0 \n", "65641 0.0 1.0 \n", "\n", " occupation_technician/engineer \n", "0 0.0 \n", "1 0.0 \n", "2 0.0 \n", "3 0.0 \n", "4 0.0 \n", "... ... 
 \n", "65637 0.0 \n", "65638 0.0 \n", "65639 0.0 \n", "65640 0.0 \n", "65641 0.0 \n", "\n", "[65642 rows x 12 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Get the MovieLens data\n", "MovieLens_CEL = _env_getdata_CEL.get_movielens_CEL()\n", "# drop the first column and keep only the Drama and Sci-Fi genre indicators\n", "MovieLens_CEL.pop(MovieLens_CEL.columns[0])\n", "MovieLens_CEL = MovieLens_CEL[MovieLens_CEL.columns.drop(['Comedy','Action', 'Thriller'])]\n", "MovieLens_CEL" ] }, { "cell_type": "code", "execution_count": 3, "id": "J__3Ozs7Uxxs", "metadata": { "id": "J__3Ozs7Uxxs" }, "outputs": [], "source": [ "n = len(MovieLens_CEL)\n", "# user features (state S): age, gender_M and the occupation dummies (columns 3, 6-11)\n", "userinfo_index = np.array([3,6,7,8,9,10,11])\n", "# user features plus the Sci-Fi indicator (column 5); not used by the T-learner below\n", "SandA = MovieLens_CEL.iloc[:, np.array([3,5,6,7,8,9,10,11])]" ] }, { "cell_type": "code", "execution_count": 4, "id": "X1VmlNjstdsN", "metadata": { "id": "X1VmlNjstdsN" }, "outputs": [], "source": [ "# control response model mu0 and treatment response model mu1\n", "mu0 = GradientBoostingRegressor(max_depth=3)\n", "mu1 = GradientBoostingRegressor(max_depth=3)\n", "\n", "# Step 1: fit mu0 on the control group (Sci-Fi, Drama == 0) and mu1 on the treatment group (Drama == 1);\n", "# the features are the user covariates and the response is the rating (column 2)\n", "control_idx = np.where(MovieLens_CEL['Drama']==0)[0]\n", "treat_idx = np.where(MovieLens_CEL['Drama']==1)[0]\n", "mu0.fit(MovieLens_CEL.iloc[control_idx, userinfo_index], MovieLens_CEL.iloc[control_idx, 2])\n", "mu1.fit(MovieLens_CEL.iloc[treat_idx, userinfo_index], MovieLens_CEL.iloc[treat_idx, 2])\n", "\n", "# Step 2: estimate the HTE by T-learner, i.e. mu1_hat(s) - mu0_hat(s) for every record\n", "HTE_T_learner = mu1.predict(MovieLens_CEL.iloc[:,userinfo_index]) - mu0.predict(MovieLens_CEL.iloc[:,userinfo_index])" ] }, { "cell_type": "markdown", "id": "FA-F8Jc_T5Lz", "metadata": { "id": "FA-F8Jc_T5Lz" }, "source": [ "Let's look at the estimated HTEs for three selected records (rows 0, 1000 and 5000):" ] }, { "cell_type": "code", "execution_count": 5, "id": "GvHnTOxmT5Lz", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 318, "status": "ok", "timestamp": 1676750150517, "user": { "displayName": "Yang Xu", "userId": "12270366590264264299" }, "user_tz": 300 }, "id": "GvHnTOxmT5Lz", "outputId": 
"7b0b76fd-f5ac-4ab8-a3c0-188e15484fe7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "T-learner: [0.3598282 0.34648075 0.35533324]\n" ] } ], "source": [ "print(\"T-learner: \",HTE_T_learner[np.array([0,1000,5000])])" ] }, { "cell_type": "code", "execution_count": 6, "id": "caa7ae93", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Choosing Drama instead of Sci-Fi is expected to increase the average rating by 0.3571 out of 5 points.\n" ] } ], "source": [ "# ATE: average of the estimated HTEs over all n records\n", "ATE_T_learner = np.sum(HTE_T_learner)/n\n", "print(\"Choosing Drama instead of Sci-Fi is expected to increase the average rating by\",round(ATE_T_learner,4), \"out of 5 points.\")" ] }, { "cell_type": "markdown", "id": "mVAZTZYTUKJ6", "metadata": { "id": "mVAZTZYTUKJ6" }, "source": [ "**Conclusion:** Consistent with the S-learner estimate, users tend to give higher ratings to dramas than to science fiction movies. The average treatment effect estimated by T-learner is larger than the one estimated by S-learner. When the treatment effect is relatively complex, fitting two separate models can yield better performance.\n", "\n", "However, in the extreme case where both $\\mu_0(s)$ and $\\mu_1(s)$ are complicated nonlinear functions of the state $s$ while their difference is just a constant, T-learner can easily overfit each model separately, yielding a nonlinear treatment effect estimate even though the true effect is constant. In such cases, other estimators are often preferred." ] }, { "cell_type": "markdown", "id": "nyirbjS5JdGh", "metadata": { "id": "nyirbjS5JdGh" }, "source": [ "## References\n", "1. Künzel, S. R., Sekhon, J. S., Bickel, P. J., and Yu, B. (2019). Metalearners for estimating heterogeneous treatment effects using machine learning. Proceedings of the National Academy of Sciences, 116, 4156–4165.\n" ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 5 }